%=====================
%this file solves the baseline T1
%uses the parallel toolbox 
%=====================
%with TDA and current income taxes

% Start from a clean session: empty workspace, cleared command window and
% no open figures.
clearvars
clc
close('all')
% Variables shared as globals.  NOTE(review): search_c / search_cr are called
% below with fewer explicit arguments than search_cr_parfor, so they presumably
% read these globals -- confirm against their source.
global vfuture AMvalues BMvalues NCini crit
global e r b eJr s climit

%=====================
%first define environment values
%=====================
JR = 16;       % first retirement period; periods 1..JR-1 are working periods
r  = 0.10;     % per-period pre-tax interest rate
b  = 1;        % discount factor (1 = no discounting)

% endowment (earnings) for workers in periods 1..JR-1, stored as a column vector
e = [67 74 79 85 89 93 96 97 98 98 97 95 93 90 86]';
eJr = 51;      % per-period endowment for retirees

climit = 47.4; % consumption floor
s = 9/10;      % per-period survival rate after retirement

% maximum TDA contribution per working period (only workers contribute):
% 36 in periods 1..10, 48 in periods 11..JR-1
qbar = [repmat(36.0, 10, 1); repmat(48.0, JR-11, 1)];



%=====================
%Define state space and choice/value function
%=====================
% Both grids square an evenly spaced base, so points are denser near zero.
NA = 50;                          % number of RA grid points
NB = 100;                         % number of TDA grid points
Avector = (linspace(0,10,NA)').^2;   % RA balances, 0..100
Bvector = (linspace(0,30,NB)').^2;   % TDA balances, 0..900
% State tables of size NB x NA: rows follow Bvector, columns follow Avector.
[AMvalues, BMvalues] = meshgrid(Avector, Bvector);


%=====================
%Parameters used for solving the model
%=====================
% retiree value function iteration
critv = 100;      % initial change, above diffv so the iteration starts
diffv = 1.0e-8;   % stop when max |V_new - V_old| <= diffv
% search tolerances / initial grid sizes
crit  = 0.0004;   % passed to the consumption search (presumably its tolerance)
NCini = 5;        % passed to the consumption search (presumably initial grid size)
NQini = 5;        % points in the coarse contribution/withdrawal grid
critq = 0.0004;   % contribution refinement stops once the step is <= critq
i = 1;            % iteration counter for the retiree loop

% Per-period value function and decision rules, indexed by period j = 1..JR.
% In each NB x NA matrix the row indexes the TDA balance and the column the
% RA balance.
[vfunc, cfunc, qfunc] = deal(repmat({zeros(NB,NA)}, JR, 1));
% Retiree value function and decision rules (same layout).
vrfunc = zeros(NB,NA);
crfunc = zeros(NB,NA);
qrfunc = zeros(NB,NA);

%solve retiree's value function and decision rules
% Value function iteration for the retiree problem.  Each sweep applies the
% Bellman operator at every (TDA, RA) grid point.  It uses the previous
% iterate as the continuation value and stops once the largest absolute
% change in the value function is no more than diffv.  critv and the sweep
% counter i are echoed after every sweep as progress output.
while (critv>diffv)
    vrfunc_old=vrfunc;
    % Continuation value for this sweep.  It is passed to search_cr_parfor
    % explicitly, together with the other parameters, rather than being read
    % as a global inside the parfor body.
    vfuture=vrfunc;

    % Empty TDA (first grid point): the withdrawal is fixed at zero.
    ib=1;
    parfor ia=1:NA
        qg=0;
        [vmax_c, cmax_c]=search_cr_parfor(Avector(ia),Bvector(ib),qg,vfuture, AMvalues, BMvalues, NCini, crit, r, b, eJr, s, climit);
        vrfunc(ib,ia)=vmax_c;
        crfunc(ib,ia)=cmax_c;
        qrfunc(ib,ia)=0;
    end %ia

    % Positive TDA: choose q in [-(1+r)*B, 0]; negative q is a withdrawal.
    for ib=2:NB
        parfor ia=1:NA
            qmax=0;
            qmin=-(1+r)*Bvector(ib);   % at most the whole grossed-up balance

            % Coarse pass: evaluate NQini evenly spaced values of q.
            vfunc_temp=zeros(NQini,1);
            cfunc_temp=zeros(NQini,1);
            qvector=linspace(qmin,qmax,NQini);
            for iq=1:NQini
                qg=qvector(iq);
                [vmax_c, cmax_c]=search_cr_parfor(Avector(ia),Bvector(ib),qg,vfuture, AMvalues, BMvalues, NCini, crit, r, b, eJr, s, climit);
                vfunc_temp(iq)=vmax_c;
                cfunc_temp(iq)=cmax_c;
            end %iq
            [vtemp_q, qtemp_g]=max(vfunc_temp);
            vmax_q=vtemp_q;
            qmax_q=qvector(qtemp_g);
            cmax_q=cfunc_temp(qtemp_g);

            % Refinement.  Halve the step each round and probe one step below
            % the incumbent; probe one step above only if the downward probe
            % did not improve.  Stops once the step is <= critq.
            diffq=(qmax-qmin)/(NQini-1);
            while diffq>critq
                diffq=diffq/2;
                if (qmax_q>qmin)
                    qg=qmax_q-diffq;
                    [vtemp,ctemp]=search_cr_parfor(Avector(ia),Bvector(ib),qg,vfuture, AMvalues, BMvalues, NCini, crit, r, b, eJr, s, climit);
                    if (vtemp>vmax_q)
                        vmax_q=vtemp;
                        cmax_q=ctemp;
                        qmax_q=qg;
                        continue   % improved downward: skip the upward probe this round
                    end %if
                end %if
                if (qmax_q<qmax)
                    qg=qmax_q+diffq;
                    % terminated with ';' so workers do not echo every evaluation
                    [vtemp,ctemp]=search_cr_parfor(Avector(ia),Bvector(ib),qg,vfuture, AMvalues, BMvalues, NCini, crit, r, b, eJr, s, climit);
                    if (vtemp>vmax_q)
                        vmax_q=vtemp;
                        cmax_q=ctemp;
                        qmax_q=qg;
                    end %if
                end %if
            end %while
            vrfunc(ib,ia)=vmax_q;
            crfunc(ib,ia)=cmax_q;
            qrfunc(ib,ia)=qmax_q;
        end %ia
    end %ib
    critv=max(max(abs(vrfunc_old-vrfunc)))   % unsuppressed: progress output
    i=i+1                                     % unsuppressed: progress output
end

%solve value function and decision rules for working agents
% Backward induction over working periods JR-1,...,1.  The retiree solution
% serves as the period-JR value.  Workers cannot withdraw, so the TDA
% contribution q lies in [0, qmax].
vfunc{JR}=vrfunc;
cfunc{JR}=crfunc;
qfunc{JR}=qrfunc;
for j=JR-1:-1:1 %cannot withdraw
    j                      % unsuppressed: progress output
    vfuture=vfunc{j+1};    % next-period value (presumably read by search_c as a global)
    for ib=1:NB
        for ia=1:NA
            % Upper bound on q: the contribution at which after-tax resources
            % reach zero.  Cash on hand is e+(1+r)*A, but only e+r*A-q is
            % taxable (interest, not principal).  This is the same taxable
            % income the simulation uses for the budget constraint.
            coh=e(j)+(1+r)*Avector(ia);
            taxbase=e(j)+r*Avector(ia);
            fun=@(x) coh-x-tax(taxbase-x);
            qmax=fzero(fun,0);
            qmax=min(e(j),qmax);     % cannot contribute more than earnings
            qmax=min(qbar(j),qmax);  % contribution limit for period j
            qmin=0;

            % Coarse pass: evaluate NQini evenly spaced contributions.
            vfunc_temp=zeros(NQini,1);
            cfunc_temp=zeros(NQini,1);
            qvector=linspace(qmin,qmax,NQini);
            for iq=1:NQini
                qg=qvector(iq);
                [vmax_c, cmax_c]=search_c(j,Avector(ia),Bvector(ib),qg);
                vfunc_temp(iq)=vmax_c;
                cfunc_temp(iq)=cmax_c;
            end %iq
            [vtemp_q, qtemp_g]=max(vfunc_temp);
            vmax_q=vtemp_q;
            qmax_q=qvector(qtemp_g);
            cmax_q=cfunc_temp(qtemp_g);

            % Refinement: halve the step, probe below; probe above only if
            % the downward probe did not improve.  Stop once step <= critq.
            diffq=(qmax-qmin)/(NQini-1);
            while diffq>critq
                diffq=diffq/2;
                if (qmax_q>qmin)
                    qg=qmax_q-diffq;
                    [vtemp,ctemp]=search_c(j,Avector(ia),Bvector(ib),qg);
                    if (vtemp>vmax_q)
                        vmax_q=vtemp;
                        cmax_q=ctemp;
                        qmax_q=qg;
                        continue
                    end %if
                end %if
                if (qmax_q<qmax)
                    qg=qmax_q+diffq;
                    [vtemp,ctemp]=search_c(j,Avector(ia),Bvector(ib),qg);
                    if (vtemp>vmax_q)
                        vmax_q=vtemp;
                        cmax_q=ctemp;
                        qmax_q=qg;
                    end %if
                end %if
            end %while
            vfunc{j}(ib,ia)=vmax_q;
            cfunc{j}(ib,ia)=cmax_q;
            qfunc{j}(ib,ia)=qmax_q;
        end %ia
    end %ib
end %j


%=====================
%simulate optimal decisions
%decisions are rounded to 1 dollar = 0.001 in thousands
%=====================
JN=100; %total number of periods only used for simulation
vopt=zeros(JN,1); %optimal value
copt=zeros(JN,1); %consumption
qopt=zeros(JN,1); %contribution/withdraw
aopt=zeros(JN+1,1); %RA balance
bopt=zeros(JN+1,1); %TDA balance
taxopt=zeros(JN,1); %tax payments
qmaxopt=zeros(JR-1,1); %maximum possible TDA contribution
taxableopt=zeros(JN,1); %taxable income
%workers
for j=1:JR-1
    vfuture=vfunc{j+1};   % presumably read by search_c as a global
    % Contribution bound: the q at which after-tax resources reach zero.
    % Cash on hand is e+(1+r)*a; taxable income is e+r*a-q (interest only),
    % the same as in the budget constraint applied below.
    coh=e(j)+(1+r)*aopt(j);
    taxbase=e(j)+r*aopt(j);
    fun=@(x) coh-x-tax(taxbase-x);
    qmax=fzero(fun,0);
    qmax=min(e(j),qmax);     % cannot contribute more than earnings
    qmax=min(qbar(j),qmax);  % contribution limit for period j
    qmaxopt(j)=qmax;
    qmin=0;

    % Coarse pass over NQini contributions.
    vfunc_temp=zeros(NQini,1);
    cfunc_temp=zeros(NQini,1);
    qvector=linspace(qmin,qmax,NQini);
    for iq=1:NQini
        qg=qvector(iq);
        [vmax_c, cmax_c]=search_c(j,aopt(j),bopt(j),qg);
        vfunc_temp(iq)=vmax_c;
        cfunc_temp(iq)=cmax_c;
    end %iq
    [vtemp_q, qtemp_g]=max(vfunc_temp);
    vmax_q=vtemp_q;
    qmax_q=qvector(qtemp_g);
    cmax_q=cfunc_temp(qtemp_g);

    % Refinement: halve the step, probe below; probe above only if the
    % downward probe did not improve.  Stop once step <= critq.
    diffq=(qmax-qmin)/(NQini-1);
    while diffq>critq
        diffq=diffq/2;
        if (qmax_q>qmin)
            qg=qmax_q-diffq;
            [vtemp,ctemp]=search_c(j,aopt(j),bopt(j),qg);
            if (vtemp>vmax_q)
                vmax_q=vtemp;
                cmax_q=ctemp;
                qmax_q=qg;
                continue
            end %if
        end %if
        if (qmax_q<qmax)
            qg=qmax_q+diffq;
            [vtemp,ctemp]=search_c(j,aopt(j),bopt(j),qg);
            if (vtemp>vmax_q)
                vmax_q=vtemp;
                cmax_q=ctemp;
                qmax_q=qg;
            end %if
        end %if
    end %while

    % Apply the (rounded) decisions; taxes are evaluated once per period.
    qopt(j)=round(qmax_q,3);
    taxableopt(j)=e(j)+r*aopt(j)-qopt(j);
    taxopt(j)=tax(taxableopt(j));
    % consumption cannot exceed after-tax resources net of the contribution
    copt(j)=min(round(cmax_q,3),floor((e(j)+(1+r)*aopt(j)-taxopt(j)-qopt(j))*1000)/1000);
    bopt(j+1)=bopt(j)*(1+r)+qopt(j);
    aopt(j+1)=e(j)+(1+r)*aopt(j)-taxopt(j)-qopt(j)-copt(j);
end
%retirees
% Simulate retirement periods JR..JN with the stationary retiree problem.
% q lies in [-(1+r)*bopt(j), 0]; negative q is a TDA withdrawal.
for j=JR:JN
    vfuture=vrfunc;   % presumably read by search_cr as a global
    qmax=0;
    qmin=-(1+r)*bopt(j);

    % Coarse pass over NQini withdrawals.
    vfunc_temp=zeros(NQini,1);
    cfunc_temp=zeros(NQini,1);
    qvector=linspace(qmin,qmax,NQini);
    for iq=1:NQini
        qg=qvector(iq);
        [vmax_c, cmax_c]=search_cr(aopt(j),bopt(j),qg);
        vfunc_temp(iq)=vmax_c;
        cfunc_temp(iq)=cmax_c;
    end %iq
    [vtemp_q, qtemp_g]=max(vfunc_temp);
    vmax_q=vtemp_q;
    qmax_q=qvector(qtemp_g);
    cmax_q=cfunc_temp(qtemp_g);

    % Refinement: halve the step, probe below; probe above only if the
    % downward probe did not improve.  Stop once step <= critq.
    diffq=(qmax-qmin)/(NQini-1);
    while diffq>critq
        diffq=diffq/2;
        if (qmax_q>qmin)
            qg=qmax_q-diffq;
            [vtemp,ctemp]=search_cr(aopt(j),bopt(j),qg);
            if (vtemp>vmax_q)
                vmax_q=vtemp;
                cmax_q=ctemp;
                qmax_q=qg;
                continue
            end %if
        end %if
        if (qmax_q<qmax)
            qg=qmax_q+diffq;
            [vtemp,ctemp]=search_cr(aopt(j),bopt(j),qg);
            if (vtemp>vmax_q)
                vmax_q=vtemp;
                cmax_q=ctemp;
                qmax_q=qg;
            end %if
        end %if
    end %while

    % Rounding a full withdrawal to 0.001 could take slightly more than the
    % grossed-up balance, leaving a negative TDA (and an inverted q range)
    % next period.  Never withdraw more than qmin = -(1+r)*bopt(j).
    qopt(j)=max(round(qmax_q,3),qmin);
    taxableopt(j)=eJr+r*aopt(j)-qopt(j);
    taxopt(j)=tax(taxableopt(j));
    % consumption cannot exceed after-tax resources net of the TDA flow
    copt(j)=min(round(cmax_q,3),floor((eJr+(1+r)*aopt(j)-taxopt(j)-qopt(j))*1000)/1000);
    bopt(j+1)=bopt(j)*(1+r)+qopt(j);
    aopt(j+1)=eJr+(1+r)*aopt(j)-taxopt(j)-qopt(j)-copt(j);
end

%utility if adopting the optimal decisions
% paymentopt(j) = sum over k = 1..j of b^(k-1)*uc(copt(k)), accumulated in
% period order.
paymentopt=zeros(JN,1);
running_total=0;
for j=1:JN
    running_total=running_total+b^(j-1)*uc(copt(j));
    paymentopt(j)=running_total;
end

% NOTE(review): despite its name, taxrateopt holds tax *amounts* (the same
% quantity as taxopt), not rates, and it is not used elsewhere in the visible
% script -- confirm the intent.
taxrateopt=zeros(JN,1);
for j=1:JN
    if j<JR
        endowment_j=e(j);
    else
        endowment_j=eJr;
    end
    taxrateopt(j)=tax(endowment_j+r*aopt(j)-qopt(j));
end

%utility if adopting the hand to mouth strategy
% Benchmark: each period consume the after-tax endowment, with no saving.
% The endowment is e(j) while working (j <= JR-1) and eJr in retirement.
% chandtomouth(j) is that consumption; paymenthandtomouth(j) accumulates
% b^(k-1)*uc(consumption) over periods k = 1..j.
paymenthandtomouth=zeros(JN,1);
chandtomouth=zeros(JN,1);
for j=1:JN
    if j<=JR-1
        endowment_j=e(j);
    else
        endowment_j=eJr;
    end
    chandtomouth(j)=endowment_j-tax(endowment_j);
    period_utility=b^(j-1)*uc(chandtomouth(j));
    if j==1
        paymenthandtomouth(j)=period_utility;
    else
        paymenthandtomouth(j)=paymenthandtomouth(j-1)+period_utility;
    end
end


% prob_d(j) is the cumulative probability of having died before period j.
% Nobody dies before retirement, so prob_d(1:JR) = 0.  From period JR on, a
% survivor dies with probability 1-s each period, and everyone is dead by
% period JN+1.
prob_d=zeros(JN+1,1);
for j=JR+1:JN
    still_alive=1-prob_d(j-1);
    prob_d(j)=prob_d(j-1)+(1-s)*still_alive;
end
prob_d(JN+1)=1; %maximum age of JN



%several statistics
% Each label is printed on its own line, followed by the unsuppressed value.
% prob_d(j+1)-prob_d(j) is the probability of dying in period j, so the
% weighted sums below are expectations over the period of death.
% 1-prob_d(j) is the probability of being alive in period j.
fprintf('expected payment hand to mouth\n')
sum((prob_d(2:JN+1)-prob_d(1:JN)).*paymenthandtomouth)
fprintf('expected payment optimal\n')
sum((prob_d(2:JN+1)-prob_d(1:JN)).*paymentopt)
fprintf('regular asset at retirement (before tax)\n')
aopt(JR)
fprintf('TDA asset at retirement (before tax)\n')
bopt(JR)
fprintf('Average consumption before retirement\n')
mean(copt(1:JR-1))
fprintf('Consumption at first period of retirement\n')
copt(JR)
fprintf('Expected tax revenue\n')
sum((1-prob_d(1:JN)).*taxopt)
fprintf('total savings (before retirement)\n')
% Savings out of after-tax earnings over the working periods:
% sum_j [q(j) + a(j+1) - (1+r)*a(j)].  By the simulated budget constraint
% this equals sum_j [e(j) - tax(j) - c(j)].  The previous version added
% (1+r)*a(j) instead of subtracting it, which double counted balances.
sum(qopt(1:JR-1)+aopt(2:JR)-(1+r)*aopt(1:JR-1))


%relabel variables with _T1
% Copy the simulated paths under _T1 names and store them, together with the
% solved value functions and decision rules, in T1opt.mat.
assetopt=bopt+aopt;        % total TDA + RA balance
assetopt_T1=assetopt;
copt_T1=copt;
qopt_T1=qopt;
aopt_T1=aopt;
bopt_T1=bopt;
paymentopt_T1=paymentopt;
taxableopt_T1=taxableopt;
taxopt_T1=taxopt;

% Savings out of the after-tax endowment: e(j) while working, eJr afterwards.
sopt_T1=zeros(JN,1);
sopt_T1(1:JR-1)=e(1:JR-1)-taxopt_T1(1:JR-1)-copt_T1(1:JR-1);
sopt_T1(JR:JN)=eJr-taxopt_T1(JR:JN)-copt_T1(JR:JN);

%save key decision rules
save T1opt.mat copt_T1 qopt_T1 aopt_T1 bopt_T1 paymentopt_T1 taxableopt_T1 taxopt_T1 sopt_T1 assetopt_T1 vrfunc crfunc qrfunc vfunc qfunc cfunc

fprintf('T1: Selected aggregate statistics\n')
% Rows, in order:
%   1. expected lifetime payoff (weighted by the probability of dying in each period)
%   2. payoff accumulated through the last working period (JR-1)
%   3. row 1 minus row 2
%   4. 0 (placeholder row; purpose not shown here)
%   5. RA balance at retirement
%   6. TDA balance at retirement
%   7. mean consumption while working
%   8. mean consumption over the first 10 retirement periods
%   9. consumption in the first retirement period
%  10. expected tax revenue
%  11. total savings out of after-tax earnings while working
% Indices are derived from JR instead of being hard-coded (15, 16, 16:25).
expected_payoff=sum((prob_d(2:JN+1)-prob_d(1:JN)).*paymentopt);
[expected_payoff; paymentopt(JR-1); expected_payoff-paymentopt(JR-1);...
    0; aopt(JR); bopt(JR);  mean(copt(1:JR-1));  mean(copt(JR:JR+9)); copt(JR); sum((1-prob_d(1:JN)).*taxopt); sum(sopt_T1(1:JR-1))]

fprintf('T1: decision rules\n')
% Column legend matches the column order of T1 below.
fprintf('1.consumption   2.cumulative payoff   3.TDA contribution   4.other savings   5.TDA balance   6.RA balance   7.taxable income   8.taxes   9.survival prob.\n')
T1=[copt_T1(1:JN) paymentopt_T1(1:JN) qopt_T1(1:JN) sopt_T1(1:JN)-qopt_T1(1:JN)  bopt_T1(1:JN)  aopt_T1(1:JN)  taxableopt_T1(1:JN) taxopt_T1(1:JN)  (1-prob_d(1:JN))]
